import matplotlib.pyplot as plt
from matplotlib.pyplot import MultipleLocator
import numpy as np
import math
from matplotlib import cm
import time

def wait(key, prompt=None):
    """Block until the user types exactly *key* and presses Enter.

    Args:
        key: the line (without its trailing newline) that ends the wait.
            An empty string means a bare Enter is enough.
        prompt: text shown before each read.  By default it tells the user
            which key to type, so the prompt never asks for a plain Enter
            that would not actually end the wait.
    """
    if prompt is None:
        if key == "":
            prompt = "Press Enter to continue... "
        else:
            prompt = f"Type '{key}' and press Enter to continue... "
    while True:
        try:
            choice = input(prompt)
        except EOFError:
            # stdin is closed (e.g. non-interactive run): nothing more can be
            # typed, so waiting longer is pointless.
            break
        if choice == key:
            break

class ArtificialField:
    """Artificial-potential-field path planner on a square-cell grid.

    The map is ``Width`` x ``Height`` cells, each ``gridSize`` plot units wide.
    Cells are addressed as ``(x, y)`` tuples and every per-cell array
    (``grid``, ``field_array``) is indexed ``[x][y]``, i.e. has shape
    ``(Width, Height)``.

    Constructing an instance builds the grid, the obstacles and the potential
    field, and draws the "map" and "Field" matplotlib figures; call
    :meth:`planning` afterwards to run the descent.
    """

    # Minimum potential written into a cell detected as a local minimum during
    # planning, so the descent is pushed out of it.
    LOCAL_MIN_PENALTY = 20

    def __init__(self, width, height, gridSize, robot_location, target_location, method):
        self.Width = width  # number of cells along x
        self.Height = height  # number of cells along y
        self.gridSize = gridSize  # size of one cell in plot units
        self.grid = None  # occupancy array, shape (Width, Height); 1 marks an obstacle
        self.obs = None  # list of obstacle cells as (x, y) tuples
        self.Field = None  # not read anywhere in this file; kept for compatibility
        self.field_array = None  # potential per cell, shape (Width, Height)
        self.robotLoc = robot_location  # start cell (x, y)
        self.targetLoc = target_location  # goal cell (x, y)
        # Allow drawing the animated figure; currently it only gates showing
        # the map window in creatObstacle.
        self.ENABLE_DRAW = True
        self.creatGrid()  # allocate arrays and draw the empty map
        self.creatObstacle()  # build and draw the obstacle list
        self.writeObstacle()  # mark obstacles in the occupancy grid
        self.creatField(method)  # compute the potential field
        self.drawField()  # 3-D plot of the potential

    def creatGrid(self):
        """Allocate the occupancy/potential arrays and draw the empty map.

        Both arrays have shape ``(Width, Height)`` because every caller
        indexes them ``[x][y]``.  The start (red circle) and target (green
        triangle) are drawn on the "map" figure.
        """
        self.grid = np.zeros((int(self.Width), int(self.Height)))
        self.field_array = self.grid.copy()

        plt.figure("map")
        plt.axis('scaled')  # equal scale on both axes so cells are square
        axis = plt.gca()
        # One tick per cell boundary, so plt.grid draws the cell outlines.
        axis.set_xticks(np.arange(0, (self.Width + 1) * self.gridSize, self.gridSize))
        axis.set_yticks(np.arange(0, (self.Height + 1) * self.gridSize, self.gridSize))
        axis.tick_params(axis='both', direction='in')
        plt.grid(True)

        self.drawPoint(self.robotLoc, markersize=6, marker='o',
                       markerfacecolor='red', markeredgecolor='red')
        self.drawPoint(self.targetLoc, markersize=6, marker='v',
                       markerfacecolor='green', markeredgecolor='green')

    def creatObstacle(self, with_border=False):
        """Build the hard-coded obstacle list ``self.obs`` and draw it.

        The layout assumes a map of at least 32 x 32 cells: five rectangular
        blocks plus a two-cell-thick L-shaped wall near the start that is
        meant to create a local minimum for the planner.

        Args:
            with_border: also add the outermost ring of cells as walls
                (disabled by default).
        """
        def _rect(xs, ys):
            # Every cell of the rectangle xs x ys, x-major order.
            return [(a, b) for a in xs for b in ys]

        self.obs = (
            _rect(range(10, 14), range(10, 14))
            + _rect(range(22, 24), range(12, 16))
            + _rect(range(20, 28), range(20, 24))
            + _rect(range(28, 32), range(28, 32))
            + _rect(range(20, 24), range(28, 32))
            # Walls forming a pocket around the start area (local minimum).
            + _rect([8], range(9))
            + _rect([9], range(9))
            + _rect(range(3, 8), [8])
            + _rect(range(3, 8), [9])
        )

        if with_border:
            width, height = int(self.Width), int(self.Height)
            self.obs += (
                [(0, a) for a in range(height)]
                + [(width - 1, a) for a in range(height)]
                + [(a, 0) for a in range(width)]
                + [(a, height - 1) for a in range(width)]
            )

        for cell in self.obs:
            self.drawPoint(cell, markersize=4, marker='s',
                           markerfacecolor='black', markeredgecolor='black')
        if self.ENABLE_DRAW:
            plt.ioff()
            plt.gcf().show()

    def writeObstacle(self):
        """Set every obstacle cell of ``self.obs`` to 1 in ``self.grid``."""
        for x, y in self.obs:
            self.grid[x, y] = 1

    def creatField(self, method, attract_gain=0.2, repulse_gain=1.5,
                   influence=2, obstacle_attract_gain=0.3):
        """Fill ``self.field_array`` with the potential of every cell.

        Free cells get attraction + repulsion.  Obstacle cells get the largest
        repulsion found on any free cell plus their own attraction computed
        with ``obstacle_attract_gain``, which is intended to keep them above
        nearby free cells.

        Args:
            method: potential model; only 'distance' is supported.
            attract_gain: gain passed to gravitate for free cells.
            repulse_gain: gain passed to repulsive.
            influence: distance beyond which an obstacle exerts no repulsion.
            obstacle_attract_gain: gain passed to gravitate for obstacle cells.

        Raises:
            ValueError: if ``method`` is not supported.
        """
        width, height = int(self.Width), int(self.Height)
        if self.field_array is None:
            self.field_array = np.zeros((width, height))
        blocked = set(self.obs)  # O(1) membership instead of scanning the list per cell

        the_max = 0  # largest repulsion over all free cells
        for i in range(width):
            for j in range(height):
                if (i, j) in blocked:
                    continue
                repu = self.repulsive((i, j), self.obs, method, repulse_gain, influence)
                the_max = max(the_max, repu)
                self.field_array[i][j] = (
                    self.gravitate((i, j), self.targetLoc, method, attract_gain) + repu)

        for x, y in self.obs:
            self.field_array[x][y] = the_max + self.gravitate(
                (x, y), self.targetLoc, method, obstacle_attract_gain)

    def gravitate(self, pos_now, pos_end, method, param):
        """Return the attractive potential of ``pos_end`` at ``pos_now``.

        For method 'distance' this is ``param * d / 2``, where ``d`` is the
        Euclidean distance between the two cells.

        Raises:
            ValueError: if ``method`` is not 'distance'.
        """
        if method != 'distance':
            raise ValueError(f"unsupported potential method: {method!r}")
        return param * math.hypot(pos_now[0] - pos_end[0], pos_now[1] - pos_end[1]) / 2

    def repulsive(self, pos_now, obs_map, method, param, threshold):
        """Return the total repulsive potential of ``obs_map`` at ``pos_now``.

        Each obstacle at Euclidean distance ``d <= threshold`` contributes
        ``0.5 * param * (1/d - 1/threshold) ** 2``; this falls to zero exactly
        at ``threshold``, so the field is continuous there.  Farther obstacles
        contribute nothing.

        Raises:
            ValueError: if ``method`` is not 'distance', or if ``pos_now``
                coincides with an obstacle (the potential is undefined there).
        """
        if method != 'distance':
            raise ValueError(f"unsupported potential method: {method!r}")
        rep = 0
        for obs_pos in obs_map:
            dist = math.hypot(pos_now[0] - obs_pos[0], pos_now[1] - obs_pos[1])
            if dist == 0:
                raise ValueError(f"position {pos_now} coincides with an obstacle")
            if dist <= threshold:
                rep += 0.5 * param * (1 / dist - 1 / threshold) ** 2
        return rep

    def drawField(self):
        """Plot ``field_array`` as a 3-D surface in the "Field" figure."""
        fig = plt.figure("Field")
        ax = fig.add_subplot(111, projection='3d')
        # meshgrid returns (Height, Width) arrays whose X varies along columns,
        # so the [x][y]-indexed field must be transposed to line up with them.
        X, Y = np.meshgrid(np.arange(0, self.Width, 1), np.arange(0, self.Height, 1))
        ax.plot_surface(X, Y, self.field_array.T, rstride=1, cstride=1,
                        cmap=cm.coolwarm, linewidth=0, antialiased=False)
        plt.ioff()
        plt.gcf().show()

    def drawPoint(self, loction, markersize=4, marker='s', markerfacecolor='black', markeredgecolor='black'):
        """Draw a marker at the centre of grid cell ``loction`` (an (x, y) cell).

        Grid lines sit on the tick marks (cell boundaries), so half a cell is
        added to place the marker in the middle of the cell.
        """
        half = self.gridSize / 2
        plt.plot(loction[0] * self.gridSize + half, loction[1] * self.gridSize + half,
                 linestyle='',
                 markersize=markersize,
                 marker=marker,
                 markerfacecolor=markerfacecolor,
                 markeredgecolor=markeredgecolor)

    def planning(self, max_steps=None):
        """Walk from ``robotLoc`` to ``targetLoc`` by steepest descent and plot it.

        At each step the robot moves to the in-bounds, non-obstacle neighbour
        (8-connected, see get_motion_model) with the most negative potential
        change per unit step length.  If no neighbour is strictly downhill,
        the cell is treated as a local minimum: its potential is raised to at
        least LOCAL_MIN_PENALTY and the robot retreats one cell along its path.
        This mutates ``self.field_array``.

        Each cell the robot steps into is drawn as a blue dot on the "map"
        figure, with a short pause every 5 path cells so the animation is
        visible.  The final path is drawn as a red line and printed.

        Args:
            max_steps: maximum number of loop iterations before giving up;
                defaults to ``8 * Width * Height``.  This bounds the search
                when the target is unreachable.

        Returns:
            The path as a list of (x, y) cells starting at ``robotLoc``.  It
            ends at ``targetLoc`` unless ``max_steps`` ran out.
        """
        width, height = int(self.Width), int(self.Height)
        if max_steps is None:
            max_steps = 8 * width * height
        blocked = set(self.obs or ())
        field = self.field_array

        plt.figure("map")
        current = self.robotLoc
        path = [current]
        steps = 0
        while current != self.targetLoc:
            if steps >= max_steps:
                print(f"planning stopped after {max_steps} steps without reaching the target")
                break
            steps += 1

            here = field[current[0]][current[1]]
            best_grad = 0
            best_pos = None
            for dx, dy, cost in self.get_motion_model():
                nx, ny = current[0] + dx, current[1] + dy
                if not (0 <= nx < width and 0 <= ny < height) or (nx, ny) in blocked:
                    continue
                grad = (field[nx][ny] - here) / cost
                if grad <= best_grad:  # '<=' keeps the last of equally steep moves
                    best_grad = grad
                    best_pos = (nx, ny)

            if best_grad == 0:
                # Local minimum or plateau: fill it in (always strictly raising
                # it) and retreat one cell so the dead end leaves the path.
                field[current[0]][current[1]] = max(self.LOCAL_MIN_PENALTY, here + 1)
                if len(path) > 1:
                    path.pop()
                current = path[-1]
                continue

            current = best_pos
            path.append(current)
            self.drawPoint(current, marker='o', markerfacecolor='blue', markeredgecolor='blue')
            if len(path) % 5 == 0:
                plt.pause(0.001)

        half = self.gridSize / 2
        xs = [x * self.gridSize + half for x, _ in path]
        ys = [y * self.gridSize + half for _, y in path]
        plt.plot(xs, ys, color='red', linewidth=2)
        plt.show()
        print(path)
        return path

    @staticmethod
    def get_motion_model():
        """Return the 8-connected moves as ``[dx, dy, step_length]`` lists."""
        diag = math.sqrt(2)
        return [[1, 0, 1],
                [0, 1, 1],
                [-1, 0, 1],
                [0, -1, 1],
                [-1, -1, diag],
                [-1, 1, diag],
                [1, -1, diag],
                [1, 1, diag]]



def main():
    """Plan a path across the demo 40 x 40 map, then wait for the user to type 'b'."""
    planner = ArtificialField(
        40,         # grid width in cells
        40,         # grid height in cells
        5,          # cell size in plot units
        (0, 0),     # start cell
        (35, 35),   # target cell
        'distance',
    )
    planner.planning()
    wait('b')


if __name__ == '__main__':
    main()